#encoding=utf-8

from data import DataProcess
from sklearn.metrics import confusion_matrix

if __name__ == "__main__":
    # Evaluate a saved multimodal (text + image) Keras model on the Fakeddit
    # test split and print the resulting confusion matrix.
    import numpy as np

    # Build the test-set inputs. DataProcess is project-local; presumably the
    # first argument is the split TSV and the second the dataset root -- TODO confirm.
    dp_test = DataProcess("/root/autodl-tmp/fakeddit/test.tsv", "/root/autodl-tmp/fakeddit")
    test_text_matrix, test_image_matrix, test_label = dp_test.getdata()

    from tensorflow.keras.models import load_model

    # NOTE(review): ".pt" is conventionally a PyTorch extension, while Keras'
    # load_model expects a SavedModel directory or an .h5/.keras file --
    # confirm this checkpoint really was written by Keras.
    model = load_model("/root/autodl-tmp/checkpoints_polity/dense_MM_model.pt")
    # Report success only once the model has actually been loaded.
    print("model load success")

    # The model is fed [text, image] inputs. argmax over axis 1 assumes the
    # output is (num_samples, num_classes) scores -- TODO confirm; a single
    # sigmoid unit would make every prediction 0 here.
    predictions = model.predict([test_text_matrix, test_image_matrix], batch_size=32)
    predictions = predictions.argmax(axis=1)

    # confusion_matrix needs class indices on both sides. If the labels come
    # back one-hot encoded, collapse them the same way as the predictions;
    # 1-D label vectors pass through unchanged.
    labels = np.asarray(test_label)
    if labels.ndim == 2 and labels.shape[1] > 1:
        labels = labels.argmax(axis=1)

    confusion = confusion_matrix(labels, predictions)
    print(confusion)

